
[WIP,POC] Faster functional modules #983


Open
wants to merge 6 commits into main

Conversation

@vmoens commented Jul 24, 2022

Proposes a new method to load weights in FunctionalModule and FunctionalModuleWithBuffers.

A map module <-> param_name <-> param_value is built up front and used to set attributes directly, so the module tree does not need to be re-traversed on every call.
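As a rough, hypothetical sketch of that idea (not the PR's actual code; build_param_map is a made-up name), the flat map could be collected once and reused for every call:

import torch
from torch import nn

def build_param_map(model):
    # collect (module, param_name, param_value) triples in a single flat list
    entries = []
    for module_name, m in model.named_modules():
        for param_name, p in list(m.named_parameters(recurse=False)):
            entries.append((m, param_name, p))
    return entries

net = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1))
for m, name, p in build_param_map(net):
    print(type(m).__name__, name, tuple(p.shape))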

Test:
The following test runs twice as fast on CPU as the current implementation:

import timeit
import tempfile

import torch
from functorch import make_functional

if __name__ == "__main__":
    # module with high param allocation cost but few operations
    net = torch.nn.Sequential(
        torch.nn.Linear(1, 1),
        torch.nn.Linear(1, 1),
        torch.nn.Sequential(
            torch.nn.Linear(1, 1),
            torch.nn.Linear(1, 1),
            torch.nn.Linear(1, 1),
            torch.nn.Linear(1, 1),
        ),
    )

    fnet, params = make_functional(net)
    x = torch.randn(1)
    print(timeit.timeit("fnet(params, x)", globals={"fnet": fnet, "x": x, "params": params}, number=10000))
    # 1.7 sec with new, 3.8 with old

    # the implementation supports serialization
    with tempfile.NamedTemporaryFile() as file:
        torch.save(fnet, file.name)
        loaded_fnet = torch.load(file.name)
        assert torch.isclose(fnet(params, x), loaded_fnet(params, x))

Other metrics:
On torchrl's DDPG, over a full forward-backward pass, the old implementation of _swap_state takes approx. 20% of the runtime with small neural nets (a 2-layer MLP with 256 cells) on CPU. The new implementation takes approx. 6% of the runtime.

@zou3519 self-requested a review on July 25, 2022
@zou3519 (Contributor) left a comment:

I like the approach -- it's awesome that it speeds up small neural nets a lot. I'm curious about your thoughts on parameter tying.

Comment on lines +59 to +63
for module_name, m in model.named_modules():
    for param_name, p in list(m.named_parameters(recurse=False)):
        delattr(m, param_name)
        setattr(m, param_name, None)
        yield (module_name, m, param_name, p)
@zou3519 (Contributor):

The reason we previously used create_names_map was for parameter tying. If someone creates a module that looks like:

class Foo(nn.Module):
    def __init__(self):
        super().__init__()
        self.bias = nn.Parameter(torch.randn(3))
        self.linear = nn.Linear(3, 3)
        self.linear.bias = self.bias

then fmodel, params = make_functional(Foo()) returns 2 Tensors (self.linear.weight and self.bias) instead of 3 Tensors. When the user calls fmodel([w, b], x), then b gets loaded to self.bias and self.linear.bias and w gets loaded to self.linear.weight.

Under the new strategy, it seems like params would have 3 tensors: [self.bias, self.linear.weight, self.linear.bias].

In general I'm not really sure what the interaction between parameter tying and make_functional should be. Thoughts?
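For reference, the deduplication described above can be observed directly on that example with plain nn.Module APIs (a small sketch, not functorch code):

import torch
from torch import nn

class Foo(nn.Module):
    def __init__(self):
        super().__init__()
        self.bias = nn.Parameter(torch.randn(3))
        self.linear = nn.Linear(3, 3)
        self.linear.bias = self.bias  # tie the two bias attributes to one Parameter

# nn.Module deduplicates tied parameters by default, so only two tensors show up
print([name for name, _ in Foo().named_parameters()])  # ['bias', 'linear.weight']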

@vmoens (Author):

I see your point. If we want to keep things as they are, we could link params to a list of modules and a list of names (instead of a single module and a single name). That will come with a slight overhead though...

It's the kind of design choice where you will always make someone unhappy (there will be someone out there who wants multiple copies of the same param), but that's probably not the majority of users.
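A hypothetical sketch of that alternative (the name swap_state_tied is invented here): each flat parameter maps to lists of targets rather than a single (module, name) pair, so tied parameters stay deduplicated at the cost of an extra inner loop:

def swap_state_tied(target_modules, target_names, values):
    # target_modules[i] and target_names[i] are parallel lists of every place
    # where values[i] must be written (one entry per tied attribute)
    old_values = []
    for modules, names, value in zip(target_modules, target_names, values):
        old_values.append(getattr(modules[0], names[0]))
        for m, name in zip(modules, names):
            setattr(m, name, value)
    return old_values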

Comment on lines +255 to +262
old_states = _swap_state(
    self.param_modules + self.buffer_modules,
    self.param_names + self.buffer_names,
    list(params) + list(buffers),
)
old_params = old_states[:len(self.param_modules)]
old_buffers = old_states[len(self.param_modules):]

@zou3519 (Contributor):

Just to make sure I understand why this is faster: is it because we no longer need to traverse through the module to find the submodules; we've already made the submodules directly available to swap their parameters out?

@vmoens (Author):

Exactly!
Instead of going through a tree of param names, we just flatten it and go through a single list of modules, one-level names and values.
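A minimal sketch of that flattened swap (illustrative only, not the PR's exact code); it assumes the parameters were already stripped from the modules at construction time, as in the hunk reviewed above, so a plain setattr is enough:

import torch
from torch import nn

def swap_state_flat(modules, names, values):
    # each entry is a (module, attribute name, value) triple; no dotted-path lookup
    old_values = []
    for module, name, value in zip(modules, names, values):
        old_values.append(getattr(module, name))
        setattr(module, name, value)
    return old_values

# tiny usage example: strip the params once, then swap with a single flat loop
net = nn.Linear(1, 1)
modules, names, params = [], [], []
for m in net.modules():
    for name, p in list(m.named_parameters(recurse=False)):
        delattr(m, name)
        setattr(m, name, None)
        modules.append(m)
        names.append(name)
        params.append(p)

old = swap_state_flat(modules, names, [p.detach() for p in params])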

Comment on lines +229 to +231
    param_module_names, param_modules, param_names, params = zip(*param_container)
else:
    param_module_names, param_modules, param_names, params = tuple(), tuple(), tuple(), tuple()
@zou3519 (Contributor) commented Jul 26, 2022:

Previously, we guaranteed that params is returned in the same order as what gets returned by original_model.parameters(). After this change, is that still true?

(Side note) To be honest, we've been thinking of changing the API so that params isn't returned as a flat list; instead we probably want to return some sort of dictionary or object so that one can easily figure out which entry in params corresponds to which parameter on the original module. This is something that a couple of users have asked us for. If we returned a dictionary, then it wouldn't matter that params isn't in the same order as what gets returned by original_module.parameters().
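To illustrate the dictionary idea (plain PyTorch, not an actual functorch API): keying the parameters by their dotted names makes the correspondence with the original module explicit, so ordering stops mattering:

import torch

net = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.Linear(1, 1))
params_dict = {name: p.detach().clone() for name, p in net.named_parameters()}
print(list(params_dict))  # ['0.weight', '0.bias', '1.weight', '1.bias']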

@vmoens (Author) commented Aug 3, 2022:

Oh man, this would be so great! I could definitely use that feature.

To be honest, I was thinking about using TensorDict from torchrl to pass params to functorch stateless modules. We could nest the dicts (e.g. d["module"]["param"] vs d["module.param"]), expand the params, change device or whatever, in batch and with little or no effort, since all those ops are built-in TensorDict methods. I think there's a good synergy we could get from TensorDict + functorch. At the moment TensorDict isn't torchscriptable though; I don't know how much trouble that is for you.
@nairbv @shagunsodhani
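A plain-dict sketch of the nesting idea mentioned above (TensorDict itself isn't used here; flatten_keys is a made-up helper): going from d["module"]["param"] to d["module.param"] is just key bookkeeping:

def flatten_keys(d, prefix=""):
    # turn nested dicts of tensors into a flat dict with dot-separated keys
    out = {}
    for key, value in d.items():
        full_key = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            out.update(flatten_keys(value, full_key))
        else:
            out[full_key] = value
    return out

# e.g. flatten_keys({"module": {"param": 1}}) == {"module.param": 1}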

@facebook-github-bot:

Hi @vmoens!

Thank you for your pull request.

We require contributors to sign our Contributor License Agreement, and yours needs attention.

You currently have a record in our system, but the CLA is no longer valid, and will need to be resubmitted.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!
